%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%-- Code/solutions for pencil and paper exercises --%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Supplementary Code for: A Step-by-Step Tutorial on Active Inference Modelling and its 
% Application to Empirical Data

% By: Ryan Smith, Karl J. Friston, Christopher J. Whyte

% Note to readers: be sure to run sections individually


%% Static perception

clear
close all
rng('default')

% prior belief over the two hidden states
D = [.75 .25]';

% likelihood mapping (applied as A'*o below, so rows index outcomes and
% columns index hidden states)
A = [.8 .2;
     .2 .8];

% the single observation: the first outcome
o = [1 0]';

% log posterior (up to a normalizing constant) = log prior + log likelihood
log_prior      = nat_log(D);
log_likelihood = nat_log(A'*o);
lns = log_prior + log_likelihood;

% softmax: exponentiate, then divide by the total so the posterior sums to one
unnormalized = exp(lns);
s = unnormalized/sum(unnormalized);

disp('Posterior over states q(s):');
disp(' ');
disp(s);

% Note: log(0) is -Inf, so for numerical reasons nat_log adds a small
% positive offset to its argument before taking the log. Because every
% entry is shifted by that offset, the answers printed here differ slightly
% from the exact solutions shown in the text.

% stop here so that running the whole file does not fall through into the
% next section (sections are meant to be run one at a time)
return

%% Dynamic perception

clear
close all
rng('default')

% priors over the state at the first time point
D = [.5 .5]';

% likelihood mapping (applied as A'*o, so rows index outcomes and columns
% index hidden states)
A = [.9 .1;
     .1 .9];

% transitions: B*s propagates beliefs forward from tau-1 to tau, and
% B'*s propagates beliefs backward from tau+1 to tau
B = [1 0;
     0 1];

% observations: o{t,tau} is the outcome for time point tau as available at
% time step t; an all-zero vector gives a flat (uninformative) likelihood
o{1,1} = [1 0]';
o{1,2} = [0 0]';
o{2,1} = [1 0]';
o{2,2} = [1 0]';

% number of timesteps
T = 2;

% initialise posterior beliefs as uniform over states at every time point
Qs = repmat([.5 .5]', 1, T);

for t = 1:T
    for tau = 1:T
        % message from the past: the prior D at the first time point,
        % otherwise the forward prediction B*s_{tau-1}
        if tau == 1
            lnPast = nat_log(D);
        else
            lnPast = nat_log(B*Qs(:,tau-1));
        end
        % message from the future: the backward prediction B'*s_{tau+1};
        % the last time point receives no contribution from the future
        if tau < T
            lnFuture = nat_log(B'*Qs(:,tau+1));
        else
            lnFuture = zeros(size(lnPast));
        end
        % likelihood of the observation available at time step t
        lnAo = nat_log(A'*o{t,tau});
        % update equation (one rule covers first, intermediate and last
        % time points, so any T >= 1 works)
        lns = .5*lnPast + .5*lnFuture + lnAo;
        % normalize using a softmax function to find posterior
        % (no semicolon, so every intermediate update is printed)
        Qs(:,tau) = (exp(lns)/sum(exp(lns)))
    end
end

Qs % final posterior beliefs over states

disp('Posterior over states q(s):');
disp(' ');
disp(Qs);

%% functions

% Natural log that stays finite for zero-valued inputs.
%
% y = nat_log(x) returns log(x + epsilon) elementwise, with epsilon = exp(-16),
% so zero entries map to a finite value (about -16) instead of -Inf while
% non-zero probabilities are left essentially unchanged.
% y = nat_log(x, epsilon) uses the supplied non-negative offset instead.
function y = nat_log(x, epsilon)
if nargin < 2
    % exp(-16) (~1.1e-7) is small relative to any probability of interest;
    % a larger offset such as .01 would noticeably bias small probabilities
    % (e.g. .05 would be treated as .06 and 0 as .01)
    epsilon = exp(-16);
end
y = log(x + epsilon);
end
